{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow.examples.tutorials.mnist import input_data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 1.设置输入和输出节点的个数,配置神经网络的参数。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "INPUT_NODE = 784     # 输入节点。对于MNIST数据集就等于图片像素\n",
    "OUTPUT_NODE = 10     # 输出节点。类别数目，对于MNIST为0~9十个数字，即10类\n",
    "LAYER1_NODE = 500    # 隐藏层数。这里使用只有一个隐藏层的网络结构 \n",
    "                              \n",
    "BATCH_SIZE = 100     # 每次batch打包的样本个数。数据越小越接近随机梯度下降；数据越大越接近梯度下降   \n",
    "\n",
    "# 模型相关的参数\n",
    "LEARNING_RATE_BASE = 0.8       # 基础的学习率\n",
    "LEARNING_RATE_DECAY = 0.99     # 学习率的衰减率\n",
    "REGULARAZTION_RATE = 0.0001    # 正则项的系数\n",
    "TRAINING_STEPS = 5000          # 训练迭代的总轮数\n",
    "# MOVING_AVERAGE_DECAY = 0.99    # 滑动平均的衰减率"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 2. 定义辅助函数来计算前向传播结果，使用ReLU做为激活函数。\n",
    "一个辅助函数，给定神经网络的输入和所有参数，计算神经网络的前向传播结果。这里给定了一个使用ReLU激活函数的三层全连接神经网络，通过加入隐藏层实现多层。另外这个函数也支持传入用于计算参数平均值的类，便于在测试时使用滑动平均类。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def inference(input_tensor, avg_class, weights1, biases1, weights2, biases2):\n",
    "    # 不使用滑动平均类\n",
    "    if avg_class == None:\n",
    "        layer1 = tf.nn.relu(tf.matmul(input_tensor, weights1) + biases1)\n",
    "        return tf.matmul(layer1, weights2) + biases2  # 因为在计算损失函数时会一并计算softmax函数，所以这里不需加入激活函数\n",
    "\n",
    "    else:\n",
    "        # 使用滑动平均类\n",
    "        layer1 = tf.nn.relu(tf.matmul(input_tensor, avg_class.average(weights1)) + avg_class.average(biases1))\n",
    "        return tf.matmul(layer1, avg_class.average(weights2)) + avg_class.average(biases2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 3. 定义训练过程。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(mnist):\n",
    "    # 1. 定义神经网络参数，输入输出节点\n",
    "    x = tf.placeholder(tf.float32, [None, INPUT_NODE], name='x-input')\n",
    "    y_ = tf.placeholder(tf.float32, [None, OUTPUT_NODE], name='y-input')\n",
    "    # 生成隐藏层的参数。\n",
    "    weights1 = tf.Variable(tf.truncated_normal([INPUT_NODE, LAYER1_NODE], stddev=0.1))\n",
    "    biases1 = tf.Variable(tf.constant(0.1, shape=[LAYER1_NODE]))\n",
    "    # 生成输出层的参数。\n",
    "    weights2 = tf.Variable(tf.truncated_normal([LAYER1_NODE, OUTPUT_NODE], stddev=0.1))\n",
    "    biases2 = tf.Variable(tf.constant(0.1, shape=[OUTPUT_NODE]))\n",
    " \n",
    "    # 2. 前向传播、损失函数和反向传播\n",
    "    # a. 前向传播\n",
    "    # 计算不含滑动平均类的前向传播结果\n",
    "    y = inference(x, None, weights1, biases1, weights2, biases2)\n",
    "    \n",
    "#     # 计算使用滑动平均类的前向传播结果\n",
    "#     # 定义存储训练轮数的变量，不需要计算滑动平均值，因此指定为不可训练\n",
    "    global_step = tf.Variable(0, trainable=False)\n",
    "#     # 给定滑动平均衰减率、训练轮数，初始化滑动平均类。给定训练轮数可以加快训练早期变量的更新速度\n",
    "#     variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)\n",
    "#     # 在所有代表神经网络参数的变量上使用滑动平均：即Graphkeys.TRAINABLE_VARIABLES中的元素，也即所有没有指定trainable=False的参数\n",
    "#     variables_averages_op = variable_averages.apply(tf.trainable_variables())\n",
    "#     # 滑动平均不会改变变量本身取值，只是维护一个影子变量来记录其滑动平均值，因此需要使用这个滑动平均值时需要明确调用average函数\n",
    "#     average_y = inference(x, variable_averages, weights1, biases1, weights2, biases2)\n",
    "    \n",
    "    # b. 损失函数\n",
    "    # 计算交叉熵及其平均值，目标类别只有一个正确答案时可使用tf.nn.sparse_softmax_cross_entropy_with_logits来加速交叉熵计算\n",
    "    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, axis=1))\n",
    "    # 计算当前batch中所有样例的交叉熵平均值\n",
    "    cross_entropy_mean = tf.reduce_mean(cross_entropy)\n",
    "    # 计算L2正则化损失，一般只计算权重部分，不使用偏置项\n",
    "    regularizer = tf.contrib.layers.l2_regularizer(REGULARAZTION_RATE)\n",
    "    regularaztion = regularizer(weights1) + regularizer(weights2)\n",
    "    # 总的损失函数\n",
    "    loss = cross_entropy_mean + regularaztion\n",
    "    \n",
    "    # c. 反向传播\n",
    "    # 设置指数衰减的学习率。\n",
    "    learning_rate = tf.train.exponential_decay(\n",
    "        LEARNING_RATE_BASE,\n",
    "        global_step,\n",
    "        mnist.train.num_examples / BATCH_SIZE,\n",
    "        LEARNING_RATE_DECAY,\n",
    "        staircase=True)\n",
    "    # 优化损失函数（更新神经网络参数）\n",
    "    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)\n",
    "    # 每一遍训练，反向传播既需要更新参数，也需要更新每一个参数的滑动平均值，train_op = tf.group(train_step, variables_average_op)等价下面两行\n",
    "#     with tf.control_dependencies([train_step, variables_averages_op]):\n",
    "#         train_op = tf.no_op(name='train')\n",
    "    train_op = train_step\n",
    "\n",
    "    # 另；检验神经网络前向传播是否正确\n",
    "    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))  # tf.equal判断两个张量的每一维是否相等，返回True/False\n",
    "#     correct_prediction = tf.equal(tf.argmax(average_y, 1), tf.argmax(y_, 1))  # tf.equal判断两个张量的每一维是否相等，返回True/False\n",
    "    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))  # tf.cast转换数据类型，这里将布尔型转换成实数型\n",
    "    \n",
    "    # 3. 建立会话，训练模型\n",
    "    with tf.Session() as sess:\n",
    "        tf.global_variables_initializer().run()\n",
    "        # 准备验证数据，一般在神经网络训练过程中通过其来大致判断停止条件和评判训练的结果\n",
    "        validate_feed = {x: mnist.validation.images, y_: mnist.validation.labels}\n",
    "        # 准备验证数据，在实际的应用中，这部分数据在训练时是不可见的，这个数据作为模型优劣的最终评价标准\n",
    "        test_feed = {x: mnist.test.images, y_: mnist.test.labels} \n",
    "        \n",
    "        # 循环地训练神经网络\n",
    "        for i in range(TRAINING_STEPS):\n",
    "            if i % 1000 == 0:\n",
    "                # 因MNIST数据集较小这里划分为更小的batch，当神经网络比较复杂或验证集比较大时，太大的batch会导致计算时间太长甚至发生内存溢出\n",
    "                validate_acc = sess.run(accuracy, feed_dict=validate_feed)\n",
    "                print(\"After %d training step(s), validation accuracy using average model is %g \" % (i, validate_acc))\n",
    "                \n",
    "                ##### 5.2.2节适用######\n",
    "#                 test_acc = sess.run(accuracy, feed_dict=test_feed)\n",
    "#                 print((\"After %d training step(s), test accuracy using average model is %g\" %(TRAINING_STEPS, test_acc)))\n",
    "                #######################\n",
    "            \n",
    "            xs, ys = mnist.train.next_batch(BATCH_SIZE)\n",
    "            sess.run(train_op, feed_dict={x:xs, y_:ys})\n",
    "\n",
    "        test_acc = sess.run(accuracy, feed_dict=test_feed)\n",
    "        print((\"After %d training step(s), test accuracy using average model is %g\" %(TRAINING_STEPS, test_acc)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 4. 主程序入口，这里设定模型训练次数为5000次。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From <ipython-input-5-7c47de531388>:2: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use alternatives such as official/mnist/dataset.py from tensorflow/models.\n",
      "WARNING:tensorflow:From d:\\python3\\tfgpu\\dl+\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please write your own downloading logic.\n",
      "WARNING:tensorflow:From d:\\python3\\tfgpu\\dl+\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use tf.data to implement this functionality.\n",
      "Extracting ../../../../datasets/MNIST_data\\train-images-idx3-ubyte.gz\n",
      "WARNING:tensorflow:From d:\\python3\\tfgpu\\dl+\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use tf.data to implement this functionality.\n",
      "Extracting ../../../../datasets/MNIST_data\\train-labels-idx1-ubyte.gz\n",
      "WARNING:tensorflow:From d:\\python3\\tfgpu\\dl+\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use tf.one_hot on tensors.\n",
      "Extracting ../../../../datasets/MNIST_data\\t10k-images-idx3-ubyte.gz\n",
      "Extracting ../../../../datasets/MNIST_data\\t10k-labels-idx1-ubyte.gz\n",
      "WARNING:tensorflow:From d:\\python3\\tfgpu\\dl+\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use alternatives such as official/mnist/dataset.py from tensorflow/models.\n",
      "After 0 training step(s), validation accuracy using average model is 0.0514 \n",
      "After 1000 training step(s), validation accuracy using average model is 0.974 \n",
      "After 2000 training step(s), validation accuracy using average model is 0.975 \n",
      "After 3000 training step(s), validation accuracy using average model is 0.978 \n",
      "After 4000 training step(s), validation accuracy using average model is 0.9824 \n",
      "After 5000 training step(s), test accuracy using average model is 0.9824\n"
     ]
    }
   ],
   "source": [
    "def main(argv=None):\n",
    "    mnist = input_data.read_data_sets(\"../../../../datasets/MNIST_data\", one_hot=True)\n",
    "    train(mnist)\n",
    "\n",
    "if __name__=='__main__':\n",
    "    main()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
